from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn

from auxiliary.settings import DEVICE
from classes.modules.common.alexnet.AlexNetLoader import AlexNetLoader

class RCCNet(nn.Module):
    """Two-branch recurrent network over a pair of image sequences.

    Each input sequence is processed by its own branch:
      1. each frame is encoded by a truncated, pretrained AlexNet backbone;
      2. the result is average-pooled over the spatial dimensions;
      3. it is fed step by step into an LSTM cell.

    The final hidden states of both branches are concatenated and mapped by a
    small MLP to 3 values in (0, 1) (final Sigmoid).
    """

    def __init__(self, hidden_size: int = 128, kernel_size: int = 5):
        """Build both branches and the prediction head.

        Args:
            hidden_size: Hidden/cell state size of each LSTM cell. The default
                (128) reproduces the original fixed architecture, so existing
                checkpoints keep the same parameter shapes.
            kernel_size: Currently unused; kept for interface compatibility.
        """
        super().__init__()

        # NOTE(review): the device is hard-coded (the DEVICE setting is
        # commented out). forward() does not read this attribute; it allocates
        # its recurrent state on the device of its inputs. Confirm whether any
        # external code relies on `model.device` before changing it.
        # self.device = DEVICE
        self.device = torch.device("cuda:1")

        # Two independent copies of the pretrained backbone. Each copy is the
        # first 12 layers of the loaded model's first child module (presumably
        # AlexNet's `features`; TODO confirm). The truncated backbone must
        # output 256 channels to match the LSTM input size below.
        s1 = AlexNetLoader().load(pretrained=True)
        self.alexnet1_1_A = nn.Sequential(*list(s1.children())[0][:12])

        s2 = AlexNetLoader().load(pretrained=True)
        self.alexnet1_1_B = nn.Sequential(*list(s2.children())[0][:12])

        # Not used by forward(); kept so attribute access elsewhere stays valid.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        self.lstm_A = nn.LSTMCell(256, hidden_size)
        self.lstm_B = nn.LSTMCell(256, hidden_size)

        # Head over the concatenated final hidden states of both branches.
        self.fc = nn.Sequential(
            nn.Linear(2 * hidden_size, 64),
            nn.Sigmoid(),
            nn.Linear(64, 3),
            nn.Sigmoid(),
        )

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Run both sequences through their branches and predict 3 values.

        Args:
            a: Sequence for branch A, shaped (batch, steps, channels, H, W).
            b: Sequence for branch B; must have exactly the same shape as `a`.

        Returns:
            Tensor of shape (batch, 3) with values in (0, 1).

        Raises:
            ValueError: If `a` and `b` have different shapes.
        """
        if a.shape != b.shape:
            raise ValueError(
                f"a and b must have the same shape, "
                f"got {tuple(a.shape)} and {tuple(b.shape)}"
            )
        nbatch, nstep, nchan, h, w = a.shape

        # Fold steps into the batch so every frame is encoded independently.
        # reshape (unlike view) also accepts non-contiguous inputs.
        a = self.alexnet1_1_A(a.reshape(nbatch * nstep, nchan, h, w))
        b = self.alexnet1_1_B(b.reshape(nbatch * nstep, nchan, h, w))

        # Global average pool over the spatial dims, then restore
        # (batch, steps, features).
        a = torch.mean(a, dim=(2, 3)).view(nbatch, nstep, -1)
        b = torch.mean(b, dim=(2, 3)).view(nbatch, nstep, -1)

        # Zero initial states on the features' own device/dtype, so the model
        # runs wherever it (and its inputs) were moved.
        hidden_state_A = a.new_zeros((nbatch, self.lstm_A.hidden_size))
        cell_state_A = a.new_zeros((nbatch, self.lstm_A.hidden_size))
        hidden_state_B = b.new_zeros((nbatch, self.lstm_B.hidden_size))
        cell_state_B = b.new_zeros((nbatch, self.lstm_B.hidden_size))

        for t in range(nstep):
            hidden_state_A, cell_state_A = self.lstm_A(
                a[:, t, :], (hidden_state_A, cell_state_A)
            )
            hidden_state_B, cell_state_B = self.lstm_B(
                b[:, t, :], (hidden_state_B, cell_state_B)
            )

        # Only the final hidden state of each branch feeds the head.
        c = torch.cat((hidden_state_A, hidden_state_B), dim=1)
        return self.fc(c)


